{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Registration Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import SimpleITK as sitk\n",
    "\n",
    "print(sitk.Version())\n",
    "from myshow import myshow\n",
    "\n",
    "# Download data to work on\n",
    "%run update_path_to_download_script\n",
    "from downloaddata import fetch_data as fdata\n",
    "\n",
    "OUTPUT_DIR = \"Output\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This section of the Visible Human Male is about 1.5 GB. To expedite processing and registration, we crop the region of interest and reduce the resolution. Note that the physical space is maintained through these operations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the RGB volume, crop the region of interest along the first two\n",
    "# axes (every slice is kept), then shrink by 3 along those same two axes.\n",
    "fixed_rgb = sitk.BinShrink(\n",
    "    sitk.ReadImage(fdata(\"vm_head_rgb.mha\"))[735:1330, 204:975, :],\n",
    "    [3, 3, 1],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch (if needed) and read the MRI volume that will be registered to the\n",
    "# RGB data.\n",
    "moving = sitk.ReadImage(\n",
    "    fdata(\"vm_head_mri.mha\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the MRI volume as loaded, before any cast or resampling.\n",
    "myshow(moving)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Segment the blue ice by growing a region from a single seed index.\n",
    "seeds = [[10, 10, 10]]\n",
    "fixed_mask = sitk.VectorConfidenceConnected(\n",
    "    fixed_rgb,\n",
    "    seedList=seeds,\n",
    "    numberOfIterations=4,\n",
    "    multiplier=8,\n",
    "    initialNeighborhoodRadius=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take everything the ice segmentation did NOT cover, label its connected\n",
    "# components, and keep the largest one (label 1 after relabeling).\n",
    "# NOTE: this overwrites fixed_mask, so re-run the segmentation cell before\n",
    "# re-running this one.\n",
    "fixed_mask = (\n",
    "    sitk.RelabelComponent(sitk.ConnectedComponent(fixed_mask == 0)) == 1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show fixed_rgb restricted to the mask; the trailing semicolon suppresses\n",
    "# the cell's return-value display.\n",
    "myshow(sitk.Mask(fixed_rgb, fixed_mask));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use component 0 (the red channel) of the RGB volume as the scalar fixed\n",
    "# image, and convert both registration inputs to 32-bit float.\n",
    "fixed = sitk.Cast(\n",
    "    sitk.VectorIndexSelectionCast(fixed_rgb, 0),\n",
    "    sitk.sitkFloat32,\n",
    ")\n",
    "moving = sitk.Cast(moving, sitk.sitkFloat32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize a 3D Euler transform from the fixed mask and the moving image\n",
    "# using the MOMENTS mode. The mask is cast to the moving image's pixel type\n",
    "# so both initializer inputs share one pixel type.\n",
    "initialTransform = sitk.CenteredTransformInitializer(\n",
    "    sitk.Cast(fixed_mask, moving.GetPixelID()),\n",
    "    moving,\n",
    "    sitk.Euler3DTransform(),\n",
    "    sitk.CenteredTransformInitializerFilter.MOMENTS,\n",
    ")\n",
    "print(initialTransform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def command_iteration(method):\n",
    "    \"\"\"Print the optimizer iteration, metric value and position of `method`.\n",
    "\n",
    "    `method` must provide GetOptimizerIteration(), GetMetricValue() and\n",
    "    GetOptimizerPosition(); later cells pass the registration object R.\n",
    "    \"\"\"\n",
    "    # flush=True emits each line immediately and avoids depending on a\n",
    "    # `sys` import that this cell does not perform.\n",
    "    print(\n",
    "        f\"{method.GetOptimizerIteration()} = {method.GetMetricValue()} : {method.GetOptimizerPosition()}\",\n",
    "        end=\"\\n\",\n",
    "        flush=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First registration stage: optimize the initialized Euler transform.\n",
    "tx = initialTransform\n",
    "R = sitk.ImageRegistrationMethod()\n",
    "R.SetInitialTransform(tx)\n",
    "R.SetInterpolator(sitk.sitkLinear)\n",
    "\n",
    "# Similarity metric, evaluated on a random subset of samples.\n",
    "R.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)\n",
    "R.SetMetricSamplingStrategy(R.RANDOM)\n",
    "# A fixed seed removes the run-to-run variability that random sampling\n",
    "# would otherwise introduce.\n",
    "R.SetMetricSamplingPercentage(percentage=0.1, seed=42)\n",
    "\n",
    "# Optimizer.\n",
    "R.SetOptimizerAsGradientDescentLineSearch(learningRate=1, numberOfIterations=100)\n",
    "R.SetOptimizerScalesFromIndexShift()\n",
    "\n",
    "# Three-level multi-resolution pyramid; sigmas are in physical units.\n",
    "R.SetShrinkFactorsPerLevel([4, 2, 1])\n",
    "R.SetSmoothingSigmasPerLevel([8, 4, 2])\n",
    "R.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Drop any previously attached observers so re-running this cell does not\n",
    "# stack duplicate callbacks, then log every optimizer iteration.\n",
    "R.RemoveAllCommands()\n",
    "R.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(R))\n",
    "\n",
    "outTx = R.Execute(\n",
    "    sitk.Cast(fixed, sitk.sitkFloat32),\n",
    "    sitk.Cast(moving, sitk.sitkFloat32),\n",
    ")\n",
    "\n",
    "# Summarize the run: transform, stop condition, iteration and metric value.\n",
    "print(\n",
    "    \"-------\",\n",
    "    tx,\n",
    "    f\"Optimizer stop condition: {R.GetOptimizerStopConditionDescription()}\",\n",
    "    f\" Iteration: {R.GetOptimizerIteration()}\",\n",
    "    f\" Metric value: {R.GetMetricValue()}\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second stage: stack a fresh 3D affine transform on top of initialTransform\n",
    "# and reconfigure R with a shorter, two-level pyramid.\n",
    "tx = sitk.CompositeTransform([initialTransform, sitk.AffineTransform(3)])\n",
    "R.SetInitialTransform(tx)\n",
    "\n",
    "R.SetOptimizerAsGradientDescentLineSearch(learningRate=1, numberOfIterations=100)\n",
    "R.SetOptimizerScalesFromIndexShift()\n",
    "\n",
    "R.SetShrinkFactorsPerLevel([2, 1])\n",
    "R.SetSmoothingSigmasPerLevel([4, 1])\n",
    "R.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the second registration pass; the trailing expression displays why\n",
    "# the optimizer stopped.\n",
    "outTx = R.Execute(\n",
    "    sitk.Cast(fixed, sitk.sitkFloat32),\n",
    "    sitk.Cast(moving, sitk.sitkFloat32),\n",
    ")\n",
    "R.GetOptimizerStopConditionDescription()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resample the moving image onto the grid of fixed_rgb using the transform\n",
    "# returned by the last registration run.\n",
    "resample = sitk.ResampleImageFilter()\n",
    "resample.SetTransform(outTx)\n",
    "resample.SetInterpolator(sitk.sitkBSpline3)\n",
    "resample.SetReferenceImage(fixed_rgb)\n",
    "\n",
    "# Progress reporting: rewrite one status line in place, flush it, and print\n",
    "# a final marker when the filter ends. Registration order is kept so the\n",
    "# flush follows the print.\n",
    "resample.AddCommand(\n",
    "    sitk.sitkProgressEvent,\n",
    "    lambda: print(f\"\\rProgress: {100*resample.GetProgress():03.1f}%...\", end=\"\"),\n",
    ")\n",
    "resample.AddCommand(sitk.sitkProgressEvent, lambda: sys.stdout.flush())\n",
    "resample.AddCommand(sitk.sitkEndEvent, lambda: print(\"Done\"))\n",
    "\n",
    "out = resample.Execute(moving)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replicate the intensity-rescaled result into three identical channels and\n",
    "# cast to 8-bit vector pixels (presumably matching fixed_rgb's pixel type so\n",
    "# the two can be combined below).\n",
    "out_rgb = sitk.Cast(\n",
    "    sitk.Compose([sitk.RescaleIntensity(out)] * 3),\n",
    "    sitk.sitkVectorUInt8,\n",
    ")\n",
    "\n",
    "# Checkerboards interleaving fixed_rgb and the resampled result. The second\n",
    "# one is tiled along the first and third axes, then has axes 1 and 2 swapped.\n",
    "vis_xy = sitk.CheckerBoard(fixed_rgb, out_rgb, checkerPattern=[8, 8, 1])\n",
    "vis_xz = sitk.PermuteAxes(\n",
    "    sitk.CheckerBoard(fixed_rgb, out_rgb, checkerPattern=[8, 1, 8]),\n",
    "    [0, 2, 1],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the axis-permuted checkerboard at a reduced dpi.\n",
    "myshow(vis_xz, dpi=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Make sure the output directory exists before writing into it, so the\n",
    "# writes below do not depend on it having been created beforehand.\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "# Save the resampled moving image and both checkerboard visualizations.\n",
    "sitk.WriteImage(out, os.path.join(OUTPUT_DIR, \"example_registration.mha\"))\n",
    "sitk.WriteImage(vis_xy, os.path.join(OUTPUT_DIR, \"example_registration_xy.mha\"))\n",
    "sitk.WriteImage(vis_xz, os.path.join(OUTPUT_DIR, \"example_registration_xz.mha\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
